import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import time
import gym
import os
import copy
from itertools import chain
from tensordict import TensorDict
from einops import rearrange

from model.utils.storage import build_storage
from model.environment.wrapper import make_env
from model.criterion.quantile_loss import QuantileLoss


class PPOTrainer():
    """PPO trainer for a quantile time-series regression agent.

    The agent's policy (``actor_mean`` + ``actor_logstd``) and critic (``critic``)
    are trained on rollouts collected from vectorised ``ts-regression-v0``
    environments. Observations are indexed as ``obs[:, :, :-1]`` for the agent
    input and ``obs[:, -1, -1]`` for the regression target — presumably
    (num_envs, time, features) with the target in the last feature column;
    TODO confirm against the environment.
    """

    def __init__(self,
                 config=None,
                 data_values=None,
                 train_environment=None,
                 valid_environment=None,
                 test_environment=None,
                 agent=None,
                 logger=None,
                 writer=None,
                 wandb=None,
                 device=None,
                 **kwargs
                 ):
        """Build vectorised environments, optimizers and the rollout buffer.

        Args:
            config: dict of hyper-parameters (seed, num_envs, num_steps, learning
                rates, quantile, transition specs, schedule and path settings).
            data_values: stored as-is; not used inside this class.
            train_environment, valid_environment, test_environment: environment
                objects whose ``vars()`` are deep-copied into the params of
                freshly created ``ts-regression-v0`` environments.
            agent: module exposing ``actor_mean``, ``actor_logstd``, ``critic``,
                ``get_action_and_value``, ``get_value`` and ``get_action``.
            logger, writer, wandb: text logger, scalar writer and wandb run.
            device: torch device for the rollout buffer and inference.
        """
        self.config = config
        self.data_values = data_values

        self.train_environment = train_environment
        self.valid_environment = valid_environment
        self.test_environment = test_environment

        # One env per worker, each with its own seed offset.
        self.train_environments = gym.vector.SyncVectorEnv([
            make_env("ts-regression-v0",
                     env_params=dict(copy.deepcopy(vars(self.train_environment)), seed=config['seed'] + i))
            for i in range(config['num_envs'])
        ])

        # Validation uses a single env so plots/metrics follow one series.
        self.valid_environments = gym.vector.SyncVectorEnv([
            make_env("ts-regression-v0",
                     env_params=dict(copy.deepcopy(vars(self.valid_environment)), seed=config['seed'] + i))
            for i in range(1)
        ])

        self.agent = agent

        self.policy_optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, chain(agent.actor_mean.parameters(), [agent.actor_logstd])),
            lr=config['policy_learning_rate'], eps=1e-5)
        self.value_optimizer = optim.Adam(filter(lambda p: p.requires_grad, list(agent.critic.parameters())),
                                          lr=config['value_learning_rate'], eps=1e-5)

        self.logger = logger
        self.writer = writer
        self.wandb = wandb

        self.device = device

        self.global_step = 0
        self.check_index = 0  # number of validations already run
        self.save_index = 0   # number of periodic checkpoints already written
        self.quantile = config['quantile']
        self.quantile_loss = QuantileLoss(self.quantile)

        self.storage = self.set_storage()
        # NOTE(review): grows by num_steps entries per rollout and is never trimmed here.
        self.lookback_samples = []

    @staticmethod
    def _middle_quantile(values):
        """Return the middle column of a ``(batch, quantiles[, 1])`` tensor as shape ``(batch,)``.

        With 1 output this is column 0, with 3 outputs it is column 1 (the median).
        """
        flat = values.reshape(values.shape[0], -1)
        return flat[:, flat.shape[1] // 2]

    def _as_float_tensor(self, array):
        """Convert an env output (numpy array or tensor) to float32 on ``self.device``."""
        return torch.as_tensor(array, dtype=torch.float32, device=self.device)

    def set_storage(self):
        """Allocate one pre-sized rollout buffer per transition field.

        Every name in ``config['transition']`` must have an entry in
        ``config['transition_shape']`` with its per-step ``shape`` and ``type``;
        each buffer gets a leading ``num_steps`` dimension.
        """
        num_steps = self.config['num_steps']
        transition_shape = self.config['transition_shape']

        storage = TensorDict({}, batch_size=num_steps).to(self.device)
        for name in self.config['transition']:
            assert name in transition_shape, f"missing transition_shape entry for {name!r}"
            spec = transition_shape[name]
            storage[name] = build_storage((num_steps, *spec["shape"]), spec["type"], self.device)

        return storage

    def flatten_storage(self, storage):
        """Merge the two leading (num_steps, num_envs) dims of every field into one batch dim."""
        flat_storage = {key: rearrange(value, 'b n ... -> (b n) ...') for key, value in storage.items()}
        return TensorDict(flat_storage,
                          batch_size=self.config['num_steps'] * self.config['num_envs']).to(self.device)

    def explore_environment(self, init_state=None, init_info=None, reset=False):
        """Collect ``num_steps`` transitions per env, then compute GAE advantages and returns.

        Args:
            init_state: observation batch to start from (ignored when ``reset``).
            init_info: info dict paired with ``init_state`` (ignored when ``reset``).
            reset: reset the training environments instead of using ``init_state``.

        Returns:
            ``(state, info)`` from the last environment step, so the next rollout
            can continue where this one stopped.

        Raises:
            BrokenPipeError: re-raised if an environment worker dies during ``step``.
        """
        prefix = "train"
        num_steps = self.config['num_steps']
        num_envs = self.config['num_envs']

        if reset:
            state, info = self.train_environments.reset()
        else:
            state, info = init_state, init_info

        next_obs = self._as_float_tensor(state)
        next_done = torch.zeros(num_envs, device=self.device)

        for step in range(num_steps):
            self.global_step += num_envs
            obs, obs_state = next_obs, state

            with torch.no_grad():
                prediction, logprob, _, value = self.agent.get_action_and_value(obs[:, :, :-1])
            # Store the same middle-quantile estimate that GAE bootstraps with and
            # that update_value regresses, so the baseline is consistent.
            self.storage["training_values"][step] = self._middle_quantile(value)
            self.storage['features'][step] = obs
            self.storage['training_dones'][step] = next_done
            self.storage["training_actions"][step] = prediction.mean(dim=-1).unsqueeze(-1)
            self.storage["training_logprobs"][step] = logprob[:, 0]

            try:
                state, reward, done, truncated, info = self.train_environments.step(prediction.cpu().numpy())
            except BrokenPipeError as err:
                # A dead worker leaves reward/done undefined; continuing would train on garbage.
                self.logger.error(f"train environment step failed: {err}")
                raise
            # Pair the observation the prediction was made from with its ground truth.
            self.lookback_samples.append((obs_state, info["ground_truth"]))

            reward = torch.as_tensor(reward, dtype=torch.float32, device=self.device)
            self.storage["training_rewards"][step] = torch.clamp(reward, -10, 10)
            next_done = self._as_float_tensor(done)
            next_obs = self._as_float_tensor(state)

            self.wandb.log({
                f"{prefix}/prediction": self._middle_quantile(prediction)[0].item(),
                f"{prefix}/ground_truth": state[:, -1, -1][0],
                f"{prefix}/upper_bound": info['upper'][0],
                f"{prefix}/lower_bound": info['lower'][0],
            })

            if "final_info" in info:
                for info_item in info["final_info"]:
                    if info_item is None:
                        continue
                    self.logger.info(f"global_step={self.global_step}, mse={info_item['mse']}")
                    self.writer.add_scalar(f"{prefix}/mse", info_item["mse"], self.global_step)
                    self.wandb.log({f"{prefix}/mse": info_item["mse"]})

                    self.logger.info(f"global_step={self.global_step}, mae={info_item['mae']}")
                    self.writer.add_scalar(f"{prefix}/mae", info_item["mae"], self.global_step)
                    self.wandb.log({f"{prefix}/mae": info_item["mae"]})

        # Generalized Advantage Estimation over the rollout just collected.
        with torch.no_grad():
            next_value = self._middle_quantile(self.agent.get_value(next_obs[:, :, :-1]))
            gamma = self.config['gamma']
            gae_lambda = self.config['gae_lambda']
            last_gaelam = 0
            for t in reversed(range(num_steps)):
                if t == num_steps - 1:
                    next_nonterminal = 1.0 - next_done
                    next_values = next_value
                else:
                    next_nonterminal = 1.0 - self.storage["training_dones"][t + 1]
                    next_values = self.storage["training_values"][t + 1]

                delta = (self.storage["training_rewards"][t]
                         + gamma * next_values * next_nonterminal
                         - self.storage["training_values"][t])
                last_gaelam = delta + gamma * gae_lambda * next_nonterminal * last_gaelam
                self.storage["training_advantages"][t] = last_gaelam
            self.storage["training_returns"] = self.storage["training_advantages"] + self.storage["training_values"]

        return state, info

    def update_value(self, flat_storage, b_inds, info):
        """Run critic updates over ``b_inds`` in minibatches of ``value_minibatch_size``.

        The middle quantile of the critic output is regressed onto the GAE
        returns, optionally with PPO-style clipping around the rollout-time values.

        Args:
            flat_storage: flattened rollout buffer (``features``, ``training_values``,
                ``training_returns`` are read).
            b_inds: 1-D index tensor into ``flat_storage``.
            info: metrics dict, updated in place.

        Returns:
            ``info`` with ``v_loss`` / ``value`` from the last minibatch (unchanged
            when ``b_inds`` is empty).
        """
        mb_size = self.config['value_minibatch_size']
        clip = self.config['clip_coef']
        v_loss, new_values = None, None

        # Iterate over the actual batch length; the last split may be short.
        for start in range(0, len(b_inds), mb_size):
            mb_inds = b_inds[start:start + mb_size]
            features = flat_storage["features"][mb_inds].to(self.device)
            old_values = flat_storage["training_values"][mb_inds].to(self.device)
            returns = flat_storage["training_returns"][mb_inds].to(self.device)

            new_values = self.agent.get_value(features[:, :, :-1])
            new_value = self._middle_quantile(new_values)

            if self.config['clip_vloss']:
                v_loss_unclipped = (new_value - returns) ** 2
                v_clipped = old_values + torch.clamp(new_value - old_values, -clip, clip)
                v_loss_clipped = (v_clipped - returns) ** 2
                v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
            else:
                v_loss = 0.5 * ((new_value - returns) ** 2).mean()  # MSE

            self.value_optimizer.zero_grad()
            v_loss.backward()
            nn.utils.clip_grad_norm_(self.agent.critic.parameters(), self.config['max_grad_norm'])
            self.value_optimizer.step()

        if v_loss is not None:
            info["v_loss"] = v_loss.item()
            info["value"] = new_values.mean().item()

        return info

    def update_policy(self, flat_storage, b_inds, info):
        """Run one clipped-PPO actor update on ``b_inds``, plus MSE and quantile regression losses.

        Args:
            flat_storage: flattened rollout buffer (``features``, ``training_logprobs``,
                ``training_advantages`` are read).
            b_inds: 1-D index tensor into ``flat_storage``.
            info: current metrics; only ``total_approx_kl`` is read.

        Returns:
            dict of float metrics (losses, KL estimates, clip fraction) plus the
            detached new ``logprob`` tensor.
        """
        features = flat_storage["features"][b_inds].to(self.device)
        old_log_probs = flat_storage["training_logprobs"][b_inds]
        advantages = flat_storage["training_advantages"][b_inds]
        if advantages.numel() > 1:  # std() of a single element is NaN
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        ground_truth = features[:, -1, -1].unsqueeze(-1)
        predictions, log_probs, entropy, _ = self.agent.get_action_and_value(features[:, :, :-1])

        logratio = log_probs[:, 0] - old_log_probs
        ratio = logratio.exp()
        with torch.no_grad():
            approx_kl = ((ratio - 1) - logratio).mean().item()
            old_approx_kl = (0.5 * logratio.pow(2)).mean().item()
            clipfracs = torch.mean((torch.abs(ratio - 1.0) > self.config['clip_coef']).float()).item()

        clipped_ratio = torch.clamp(ratio, 1 - self.config['clip_coef'], 1 + self.config['clip_coef'])
        pg_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()

        # Point prediction = middle quantile, kept as a (batch, 1) column.
        pred_value = self._middle_quantile(predictions).unsqueeze(-1)
        mse_loss = F.mse_loss(pred_value, ground_truth)
        quantile_loss = self.quantile_loss(predictions, ground_truth)

        entropy_loss = entropy.mean()
        loss = (pg_loss - self.config['entropy_coef'] * entropy_loss + mse_loss
                + quantile_loss * self.config['quantile_coef'])

        self.policy_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(list(self.agent.actor_mean.parameters()) + [self.agent.actor_logstd],
                                       self.config['max_grad_norm'])
        self.policy_optimizer.step()

        return {
            "mse_loss": mse_loss.item(),
            "quantile_loss": float(quantile_loss),
            "clipfracs": clipfracs,
            "kl_explode": approx_kl > self.config['target_kl'],
            "old_approx_kl": old_approx_kl,
            "approx_kl": approx_kl,
            "total_approx_kl": float(info.get("total_approx_kl", 0.0)) + approx_kl,
            "pg_loss": pg_loss.item(),
            "entropy_loss": entropy_loss.item(),
            "logprob": log_probs[:, 0].detach(),
        }

    def train(self):
        """Alternate rollouts and PPO updates for ``num_updates`` iterations.

        Validates every ``check_steps`` and checkpoints every ``save_steps``
        environment steps, then runs a final validation and checkpoint and closes
        the environments, writer and wandb run.
        """
        prefix = "train"
        start_time = time.time()

        assert self.train_environments is not None, "train_environments failed to be initialized"
        os.makedirs(self.config['checkpoint_path'], exist_ok=True)

        state, info = self.train_environments.reset()

        for update in range(self.config['num_updates']):
            # Continue from the observation the previous rollout stopped on.
            state, info = self.explore_environment(init_state=state, init_info=info, reset=False)
            flat_storage = self.flatten_storage(self.storage)

            advantages = flat_storage["training_returns"] - flat_storage["training_values"]
            if advantages.numel() > 1:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
            flat_storage["training_advantages"] = advantages

            records_info = {
                "clipfracs": 0.0,
                "kl_explode": False,
                "policy_update_steps": 0,
                "mse_loss": 0.0,
                "quantile_loss": 0.0,
                "entropy_loss": 0.0,
                "old_approx_kl": 0.0,
                "approx_kl": 0.0,
                "total_approx_kl": 0.0,
                "v_loss": 0.0,
                "pg_loss": 0.0,
                "value": 0.0,
                "logprob": 0.0,
            }

            for _ in range(self.config['num_epochs']):
                batch_indices = torch.randperm(flat_storage.batch_size[0]).split(self.config['batch_size'])
                for b_inds in batch_indices:
                    records_info.update(self.update_value(flat_storage, b_inds, records_info))   # critic
                    records_info.update(self.update_policy(flat_storage, b_inds, records_info))  # actor

            self.logger.info(
                f"Update {update}: Policy Loss={records_info['pg_loss']}, Value Loss={records_info['v_loss']}, "
                f"Entropy Loss={records_info['entropy_loss']}, KL={records_info['approx_kl']}, KL Explode={records_info['kl_explode']}"
            )

            scalars = {
                "policy_learning_rate": self.policy_optimizer.param_groups[0]["lr"],
                "value_learning_rate": self.value_optimizer.param_groups[0]["lr"],
                "policy_loss": records_info["pg_loss"],
                "value_loss": records_info["v_loss"],
                "mse_loss": records_info["mse_loss"],
                "quantile_loss": records_info["quantile_loss"],
                "entropy_loss": records_info["entropy_loss"],
                "approx_kl": records_info["approx_kl"],
                "clipfracs": records_info["clipfracs"],
            }
            for name, value in scalars.items():
                self.writer.add_scalar(f"{prefix}/{name}", value, self.global_step)
            self.wandb.log({f"{prefix}/{name}": value for name, value in scalars.items()})

            elapsed = time.time() - start_time
            sps = int(self.global_step / elapsed) if elapsed > 0 else 0
            self.logger.info(f"SPS: {sps}")

            # Integer division: fire once each time global_step crosses another multiple.
            if self.global_step // self.config['check_steps'] >= self.check_index:
                self.valid(self.global_step)
                self.check_index += 1

            if self.global_step // self.config['save_steps'] >= self.save_index:
                torch.save(self.agent.state_dict(), os.path.join(self.config['checkpoint_path'], "{:08d}.pth".format(
                    self.global_step // self.config['save_steps'])))
                self.save_index += 1

        self.valid(self.global_step)

        torch.save(self.agent.state_dict(), os.path.join(self.config['checkpoint_path'], "{:08d}.pth".format(
            self.global_step // self.config['save_steps'] + 1)))

        self.train_environments.close()
        self.valid_environments.close()

        self.writer.close()
        self.wandb.finish()

    def valid(self, global_step):
        """Roll the agent through the validation env and report accuracy and interval metrics.

        Runs ``num_steps`` windows of ``output_window_length`` env steps each.
        Per-step predictions/bounds go to wandb; a prediction-vs-ground-truth plot
        is saved under ``fig_path``; window-averaged MSE/MAE/RMSE, mean interval
        length and coverage rate are logged.
        """
        prefix = "valid"
        num_steps = self.config['num_steps']
        window_length = self.config['output_window_length']

        state, info = self.valid_environments.reset()
        next_obs = self._as_float_tensor(state)

        mse_list, mae_list, rmse_list = [], [], []
        predictions, ground_truths, interval_lens = [], [], []
        upper_bounds, lower_bounds = [], []

        with torch.no_grad():
            for _ in range(num_steps):
                window_preds, window_gts = [], []
                for _ in range(window_length):
                    prediction = self.agent.get_action(next_obs[:, :, :-1])
                    # (num_envs, 1) column so it lines up with the target column below.
                    pred_value = self._middle_quantile(prediction).unsqueeze(-1)

                    window_preds.append(pred_value.cpu().numpy())
                    window_gts.append(next_obs[:, -1, -1].unsqueeze(-1).cpu().numpy())

                    state, reward, done, truncated, info = self.valid_environments.step(prediction.cpu().numpy())
                    next_obs = self._as_float_tensor(state)

                    self.wandb.log({
                        f"{prefix}/prediction": pred_value[0].item(),
                        f"{prefix}/ground_truth": next_obs[:, -1, -1][0].item(),
                        f"{prefix}/upper_bound": info['upper'][0],
                        f"{prefix}/lower_bound": info['lower'][0],
                        f"{prefix}/interval_len": info['interval_len'][0],
                        f"{prefix}/alpha_t": info['alpha_t'][0],
                        f"{prefix}/q_hat": info['q_hat'][0],
                    })

                window_preds = np.concatenate(window_preds, axis=0)
                window_gts = np.concatenate(window_gts, axis=0)
                errors = window_preds - window_gts
                mse_w = np.mean(errors ** 2)
                mse_list.append(mse_w)
                mae_list.append(np.mean(np.abs(errors)))
                rmse_list.append(np.sqrt(mse_w))

                predictions.append(window_preds[-1])
                ground_truths.append(window_gts[-1])
                interval_lens.append(info['interval_len'])
                upper_bounds.append(info['upper'])
                lower_bounds.append(info['lower'])

        steps = np.arange(num_steps)
        predictions = np.array(predictions).flatten()
        ground_truths = np.concatenate(ground_truths, axis=0)
        upper_bounds = np.concatenate(upper_bounds, axis=0)
        lower_bounds = np.concatenate(lower_bounds, axis=0)

        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(steps, ground_truths, label="Ground Truth", color="green")
        ax.plot(steps, predictions, label="Prediction", color="blue")
        ax.fill_between(steps, lower_bounds, upper_bounds, color="blue", alpha=0.2, label="Uncertainty Bounds")
        ax.legend()
        ax.set_title("Prediction vs Ground Truth with Uncertainty")

        os.makedirs(self.config['fig_path'], exist_ok=True)
        fig.savefig(os.path.join(self.config['fig_path'], f"prediction_{global_step}.png"))
        plt.close(fig)

        mse_loss = float(np.mean(mse_list))
        mae_loss = float(np.mean(mae_list))
        rmse_loss = float(np.mean(rmse_list))
        len_interval = np.mean(interval_lens)

        # Fraction of steps whose ground truth falls inside [lower, upper].
        in_range = (ground_truths >= lower_bounds) & (ground_truths <= upper_bounds)
        coverage_rate = in_range.sum() / len(upper_bounds)

        self.logger.info(f"Validation: MSE={mse_loss}, MAE={mae_loss}, RMSE={rmse_loss}")
        self.logger.info(f"Interval Length={len_interval}, Coverage Rate={coverage_rate}")

        self.writer.add_scalar(f"{prefix}/mse", mse_loss, global_step)
        self.writer.add_scalar(f"{prefix}/mae", mae_loss, global_step)
        self.writer.add_scalar(f"{prefix}/rmse", rmse_loss, global_step)
        self.writer.add_scalar(f"{prefix}/coverage_rate", coverage_rate, global_step)
        self.wandb.log({
            f"{prefix}/mse": mse_loss,
            f"{prefix}/mae": mae_loss,
            f"{prefix}/rmse": rmse_loss,
            f"{prefix}/coverage_rate": coverage_rate,
        })
